Load the necessary libraries and source the utility functions.

library(Survinng)
library(ggplot2)
library(cowplot)
library(dplyr)
library(tidyr)
library(SurvMetrics)
library(survival)
library(simsurv)
library(survminer)
library(torch)
library(viridis)
library(here)

# Load utility functions
source(here("utils/utils_nn_training.R"))
source(here("utils/utils_plotting.R"))


# Set figure path (absolute path below the project root, resolved by `here`)
fig_path <- here::here("figures_paper")
# Use `dir.exists()` rather than `file.exists()`: the latter is also TRUE for a
# regular file of the same name, which would skip directory creation.
# `recursive = TRUE` additionally creates missing parent directories.
if (!dir.exists(fig_path)) {
  dir.create(fig_path, recursive = TRUE)
}

# Return the full path of output file `x` inside `fig_path`.
# `fig_path` is already absolute, so it is joined with `file.path()` instead of
# being passed through `here::here()` a second time.
fig <- function(x) file.path(fig_path, x)

Time-dependent effects

Generate the data

Simulation setting: - \(10,000\) samples (\(9,500\) for training, \(500\) for testing) - \(X_1 \sim \mathcal{U}(0,1)\) has a time-dependent effect: first negative, then positive on hazard (vice versa on survival) - \(X_2 \sim \mathcal{N}(0,1)\) has a positive effect on the hazard -> negative effect on survival - \(X_3 \sim \mathcal{N}(0,1)\) has a stronger negative effect on the hazard -> positive effect on survival - \(X_4 \sim \mathcal{U}(-1,1)\) has no effect

# Fix the RNG seed so that the simulated data and the train/test split below
# are reproducible
set.seed(42)

# Simulate data
n <- 10000
# Covariates: x1 ~ U(0, 1), x2 ~ N(0, 1), x3 ~ N(0, 1), x4 ~ U(-1, 1)
x <- data.frame(x1 = runif(n, 0, 1), x2 = rnorm(n), x3 = rnorm(n), x4 = runif(n, -1, 1))

# Weibull event times. `betas` gives coefficients for x1-x3 only, so x4 has no
# effect. `tde = c(x1 = 6)` with `tdefunction = "log"` adds a time-dependent
# term for x1 (presumably an interaction with log(time); see the simsurv
# documentation). `maxt = 7` bounds the follow-up time.
simdat <- simsurv(dist = "weibull", lambdas = 0.1, gammas = 1.5, betas = c(x1 = -3, x2 = 1.7, x3 = -2.4), x = x, tde = c(x1 = 6), tdefunction = "log", maxt = 7)
# Drop the first column of the simsurv output (presumably the id column --
# TODO confirm) and rename the next column to "time"
y <- simdat[, -1]
colnames(y)[1] <- "time"
# Outcome columns followed by the covariates
dat <- cbind(y, x)

# Train/test
# 9,500 randomly drawn rows form the training set; the remaining rows are the
# test set
idx <- sample(n, 9500)
train <- dat[idx, ]
test <- dat[-idx, ]

Fit the models

# Fit one network per model type with `fit_model()`, which is defined outside
# this file (presumably in the sourced utils -- confirm).
# NOTE(review): further down, each result is used as a list whose first
# element is handed to the explainer and whose `$pred` element holds the
# predictions with `id`, `time` and `pred` columns.
ext_deephit <- fit_model("DeepHit", train, test)
ext_coxtime <- fit_model("CoxTime", train, test)
ext_deepsurv <- fit_model("DeepSurv", train, test)

Create Explainer

# Build one Survinng explainer per model from the first element of each fit
# result, using the test set as the data to explain
exp_deephit <- Survinng::explain(ext_deephit[[1]], data = test)
exp_coxtime <- Survinng::explain(ext_coxtime[[1]], data = test)
exp_deepsurv <- Survinng::explain(ext_deepsurv[[1]], data = test)

Performance Measures

# Reshape long-format predictions into a wide matrix with one row per id and
# one column per distinct value of the time column.
#
# Arguments:
#   data     Long-format data frame containing an id column, a time column
#            and a prediction column.
#   id_col   Name (string) of the id column; its values become the row names.
#   time_col Column whose values become the column names of the result.
#   pred_col Column whose values fill the cells of the result.
#
# Returns a matrix with the ids as row names and the time values as column
# names. Stops with an error if `id_col` is not present after reshaping.
prepare_matrix <- function(data, id_col = "id", time_col = "time", pred_col = "pred") {
  wide_data <- data %>%
    pivot_wider(names_from = {{time_col}}, values_from = {{pred_col}})

  wide_data <- as.data.frame(wide_data)

  # Fail loudly instead of silently returning a zero-column matrix, which is
  # what negative indexing with an empty `which()` result would produce.
  if (!id_col %in% names(wide_data)) {
    stop("Column '", id_col, "' not found in the reshaped data.", call. = FALSE)
  }

  # Move the id column into the row names
  rownames(wide_data) <- wide_data[[id_col]]
  # `drop = FALSE` keeps a data frame (and hence the time column name) even
  # when only a single time column remains.
  wide_data <- wide_data[, names(wide_data) != id_col, drop = FALSE]

  # Convert to a matrix
  as.matrix(wide_data)
}

# Prepare matrices
# One wide prediction matrix per model; for DeepHit the first column (first
# time point) is dropped.
matrix_coxtime <- prepare_matrix(ext_coxtime$pred)
matrix_deepsurv <- prepare_matrix(ext_deepsurv$pred)
matrix_deephit <- prepare_matrix(ext_deephit$pred)[,-1]

# Define survival object
# Observed test-set outcomes (time and status)
surv_obj <- Surv(test$time, test$status)

# Define time indices and sampled time
# CoxTime's sorted unique prediction times are the reference grid; `indices`
# selects `num_samples` roughly evenly spaced positions on it.
# NOTE(review): if the grid has fewer than `num_samples` points, `indices`
# contains duplicates. The same `indices` are reused for DeepSurv below, which
# assumes DeepSurv shares CoxTime's time grid and that the matrix columns are
# in ascending time order -- confirm.
t_interest <- sort(unique(ext_coxtime$pred$time))
num_samples <- 100
indices <- round(seq(1, length(t_interest), length.out = num_samples))
sampled_t <- t_interest[indices]
# DeepHit keeps its own time grid, without the first time point (matching the
# column dropped from `matrix_deephit` above)
deephit_t <- sort(unique(ext_deephit$pred$time))[-1]

# Sample matrices
sampled_matrix_coxtime <- matrix_coxtime[, indices]
sampled_matrix_deepsurv <- matrix_deepsurv[, indices]

# Calculate the Brier score for every column of a prediction matrix.
#
# Arguments:
#   matrix   Matrix of predictions; column i is passed to Brier() as `pre_sp`
#            together with the evaluation time `times[i]`.
#   times    Vector of evaluation times, one per column of `matrix`.
#   surv_obj Survival object passed as the first argument to Brier().
#
# Returns a numeric vector with one score per column of `matrix`
# (numeric(0) for a matrix without columns).
calculate_brier <- function(matrix, times, surv_obj) {
  # seq_len() is safe for zero columns (1:0 would iterate over c(1, 0));
  # vapply() guarantees a numeric result with exactly one value per column.
  vapply(
    seq_len(ncol(matrix)),
    function(i) Brier(surv_obj, pre_sp = matrix[, i], times[i]),
    numeric(1)
  )
}

# Brier scores per model: CoxTime and DeepSurv on the sampled time grid,
# DeepHit on its own time grid (first time point removed)
metrics_coxtime <- calculate_brier(sampled_matrix_coxtime, sampled_t, surv_obj)
metrics_deepsurv <- calculate_brier(sampled_matrix_deepsurv, sampled_t, surv_obj)
metrics_deephit <- calculate_brier(matrix_deephit, deephit_t, surv_obj)

# Bundle the scores of one model into a long-format data frame with the
# columns `time`, `BS` and `model`, ready to be stacked with other models
# for plotting.
combine_results <- function(metrics, times, model_name) {
  data.frame(
    time = times,
    BS = metrics,
    model = model_name
  )
}

# One long-format Brier-score table per model, stacked row-wise into a single
# data frame for plotting
data_coxtime <- combine_results(metrics_coxtime, sampled_t, "CoxTime")
data_deepsurv <- combine_results(metrics_deepsurv, sampled_t, "DeepSurv")
data_deephit <- combine_results(metrics_deephit, deephit_t, "DeepHit")
data_BS <- rbind(data_coxtime, data_deepsurv, data_deephit)

# Plot Brier scores
# Colours and line types are keyed by model name so both legends stay in sync
colorblind_palette <- c("CoxTime" = "#E69F00", "DeepSurv" = "#56B4E9", "DeepHit" = "#009E73")
model_linetypes <- c("CoxTime" = "solid", "DeepSurv" = "dashed", "DeepHit" = "dotted")

brier_plot_td <- ggplot() +
  # One Brier-score curve per model
  geom_line(
    data = data_BS,
    aes(x = time, y = BS, color = model, linetype = model)
  ) +
  # Rug marking the observed test-set times
  geom_rug(
    data = test,
    aes(x = time),
    sides = "bl",
    linewidth = 0.5,
    alpha = 0.5
  ) +
  scale_color_manual(values = colorblind_palette) +
  scale_linetype_manual(values = model_linetypes) +
  labs(
    title = "",
    x = "Time t",
    y = "Brier Score",
    color = NULL,
    linetype = NULL
  ) +
  theme_minimal(base_size = 17) +
  theme(legend.position = "bottom")
brier_plot_td


# Save plot
ggsave(
  fig("sim_td_brier_plot.pdf"),
  plot = brier_plot_td,
  width = 7,
  height = 5
)

# Calculate C-index and IBS for each model

# C-index of the predictions stored in one column (`index`) of `matrix`,
# computed against `surv_obj` with Cindex().
calculate_cindex <- function(matrix, surv_obj, index) {
  predictions_at_index <- matrix[, index]
  Cindex(surv_obj, predicted = predictions_at_index)
}

# Integrated Brier score of a prediction matrix, computed with IBS().
# `times` is forwarded as the third positional argument of IBS().
calculate_ibs <- function(matrix, times, surv_obj) {
  IBS(
    surv_obj,
    sp_matrix = matrix,
    times
  )
}

# The C-index is evaluated on a single column of each prediction matrix.
# NOTE(review): the column indices 50 (CoxTime/DeepSurv, on the sampled time
# grid) and 15 (DeepHit, on its own grid) are hard-coded -- confirm that they
# refer to comparable time points.
C_coxtime <- calculate_cindex(sampled_matrix_coxtime, surv_obj, 50)
C_deepsurv <- calculate_cindex(sampled_matrix_deepsurv, surv_obj, 50)
C_deephit <- calculate_cindex(matrix_deephit, surv_obj, 15)

# NOTE(review): for DeepHit one further leading column/time point is dropped
# here, whereas the Brier-score computation uses `matrix_deephit` and
# `deephit_t` unchanged -- confirm this is intended.
IBS_coxtime <- calculate_ibs(sampled_matrix_coxtime, sampled_t, surv_obj)
IBS_deepsurv <- calculate_ibs(sampled_matrix_deepsurv, sampled_t, surv_obj)
IBS_deephit <- calculate_ibs(matrix_deephit[,-1], deephit_t[-1], surv_obj)

# Display results
# One row per model with its C-index and integrated Brier score
res <- data.frame(
  model = c("CoxTime", "DeepSurv", "DeepHit"),
  C_index = c(C_coxtime, C_deepsurv, C_deephit),
  IBS = c(IBS_coxtime, IBS_deepsurv, IBS_deephit)
)
# Store the table in the figure directory
saveRDS(res, fig("sim_td_performance.rds"))
res
#>      model  C_index      IBS
#> 1  CoxTime 0.845782 0.058368
#> 2 DeepSurv 0.859559 0.060464
#> 3  DeepHit 0.806480 0.095961

Kaplan-Meier Survival Curves

# Categorize `x1` into two groups ("Low"/"High"), split at its median
dat$x1_group <- cut(dat$x1, 
                    breaks = quantile(dat$x1, probs = c(0, 0.5, 1)), 
                    labels = c("Low", "High"), 
                    include.lowest = TRUE)

# Create a Surv object
# NOTE: this overwrites the test-set `surv_obj` defined for the performance
# measures; from here on `surv_obj` refers to the full data set `dat`.
surv_obj <- Surv(dat$time, dat$status)

# Fit Kaplan-Meier survival curves stratified by `x1_group`
km_fit <- survfit(surv_obj ~ x1_group, data = dat)

# Plot the KM curves
km_plot <- ggsurvplot(km_fit, 
           data = dat,
           xlab = "Time t",
           ylab = "Survival Probability",
           legend.title = "x1 Group",
           palette = c("#377EB8", "#E69F00"),  
           title = "") 
# Restyle the ggplot component and add a rug of the test-set times; the curves
# themselves are estimated on the full data set `dat`.
km_plot$plot <- km_plot$plot + 
                theme_minimal(base_size = 17) +
                theme(legend.position = "bottom") +
  geom_rug(data = test, aes(x = time), sides = "bl", linewidth = 0.5, alpha = 0.5, inherit.aes = FALSE)
km_plot


# Save plot
# Only the ggplot component (`km_plot$plot`) is written to file
ggsave(fig("sim_td_km_plot.pdf"), plot = km_plot$plot, width = 7, height = 5)

Survival Prediction

# Print instances of interest
td_ids <- c(79, 428)
print(test[td_ids, ])
#>          time status        x1         x2         x3         x4
#> 1535 4.306154      1 0.1934749 -0.3069111 -0.1475012  0.6370386
#> 8202 1.417212      1 0.7526954 -0.3781274 -1.1500862 -0.6418508

# Compute Vanilla Gradient
grad_cox <- surv_grad(exp_coxtime, target = "survival", instance = td_ids)
#> [Survinng] Batch size is smaller than the number of timepoints. Setting batch size (50) to the number of timepoints (74).
grad_deephit <- surv_grad(exp_deephit, target = "survival", instance = td_ids)
grad_deepsurv <- surv_grad(exp_deepsurv, target = "survival", instance = td_ids)

# Plot survival predictions
surv_plot <- cowplot::plot_grid(
  plot_surv_pred(grad_cox) ,
  plot_surv_pred(grad_deephit),
  plot_surv_pred(grad_deepsurv),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"),
  label_x = 0.03,      
  label_size = 14) 
surv_plot


# Save plot
ggsave(fig("sim_td_surv_plot.pdf"), plot = surv_plot, width = 8, height = 14)

Explainable AI

Gradient (Sensitivity)

# Plot attributions
grad_plot <- cowplot::plot_grid(
  plot_attribution(grad_cox, label = "Grad(t)") ,
  plot_attribution(grad_deephit, label = "Grad(t)"),
  plot_attribution(grad_deepsurv, label = "Grad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
grad_plot


# Save plot
ggsave(fig("sim_td_grad_plot.pdf"), plot = grad_plot, width = 10, height = 14)
# Plot normalized attributions
grad_plot_norm <- cowplot::plot_grid(
  plot_attribution(grad_cox, normalize = TRUE, label = "Grad(t)"),
  plot_attribution(grad_deephit, normalize = TRUE, label = "Grad(t)"),
  plot_attribution(grad_deepsurv, normalize = TRUE, label = "Grad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
grad_plot_norm


# Save plot
ggsave(fig("sim_td_grad_plot_norm.pdf"), plot = grad_plot_norm, width = 10, height = 14)

SmoothGrad (Sensitivity)

# Compute SmoothGrad
sg_cox <- surv_smoothgrad(exp_coxtime, target = "survival", instance = td_ids, 
                          n = 50, noise_level = 0.1)
#> [Survinng] Batch size is smaller than the number of timepoints. Setting batch size (50) to the number of timepoints (74).
sg_deephit <- surv_smoothgrad(exp_deephit, target = "survival", instance = td_ids, 
                              n = 50, noise_level = 0.1)
sg_deepsurv <- surv_smoothgrad(exp_deepsurv, target = "survival", instance = td_ids, 
                               n = 50, noise_level = 0.1)

# Plot attributions
smoothgrad_plot <- cowplot::plot_grid(
  plot_attribution(sg_cox, label = "SG(t)"), 
  plot_attribution(sg_deephit, label = "SG(t)"), 
  plot_attribution(sg_deepsurv, label = "SG(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
smoothgrad_plot


# Save plot
ggsave(fig("sim_td_smoothgrad_plot.pdf"), plot = smoothgrad_plot, width = 10, height = 14)
# Plot normalized attributions
smoothgrad_plot_comp <- cowplot::plot_grid(
  plot_attribution(sg_cox, normalize = TRUE, label = "SG(t)"), 
  plot_attribution(sg_deephit, normalize = TRUE, label = "SG(t)"), 
  plot_attribution(sg_deepsurv, normalize = TRUE, label = "SG(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
smoothgrad_plot_comp


# Save plot
ggsave(fig("sim_td_smoothgrad_plot_norm.pdf"), plot = smoothgrad_plot_comp, width = 10, height = 14)

Gradient x Input

# Compute GradientxInput
gradin_cox <- surv_grad(exp_coxtime, instance = td_ids, times_input = TRUE)
#> [Survinng] Batch size is smaller than the number of timepoints. Setting batch size (50) to the number of timepoints (74).
gradin_deephit <- surv_grad(exp_deephit, instance = td_ids, times_input = TRUE)
gradin_deepsurv <- surv_grad(exp_deepsurv, instance = td_ids, times_input = TRUE)

# Plot attributions
gradin_plot <- cowplot::plot_grid(
  plot_attribution(gradin_cox, label = "GxI(t)"), 
  plot_attribution(gradin_deephit, label = "GxI(t)"), 
  plot_attribution(gradin_deepsurv, label = "GxI(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gradin_plot


# Save plot
ggsave(fig("sim_td_gradin_plot.pdf"), plot = gradin_plot, width = 10, height = 14)
# Plot normalized attributions
gradin_plot_norm <- cowplot::plot_grid(
  plot_attribution(gradin_cox, normalize = TRUE, label = "GxI(t)"), 
  plot_attribution(gradin_deephit, normalize = TRUE, label = "GxI(t)"), 
  plot_attribution(gradin_deepsurv, normalize = TRUE, label = "GxI(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gradin_plot_norm


# Save plot
ggsave(fig("sim_td_gradin_plot_norm.pdf"), plot = gradin_plot_norm, width = 10, height = 14)

SmoothGrad x Input

# Compute SmoothGradxInput
sgin_cox <- surv_smoothgrad(exp_coxtime, instance = td_ids, n = 50, noise_level = 0.3,
                          times_input = TRUE)
#> [Survinng] Batch size is smaller than the number of timepoints. Setting batch size (50) to the number of timepoints (74).
sgin_deephit <- surv_smoothgrad(exp_deephit, instance = td_ids, n = 50, noise_level = 0.3,
                              times_input = TRUE)
sgin_deepsurv <- surv_smoothgrad(exp_deepsurv, instance = td_ids, n = 50, noise_level = 0.3,
                               times_input = TRUE)

# Plot attributions
smoothgradin_plot <- cowplot::plot_grid(
  plot_attribution(sgin_cox, label = "SGxI(t)"), 
  plot_attribution(sgin_deephit, label = "SGxI(t)"), 
  plot_attribution(sgin_deepsurv, label = "SGxI(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
smoothgradin_plot


# Save plot
ggsave(fig("sim_td_smoothgradin_plot.pdf"), plot = smoothgradin_plot, width = 10, height = 14)
# Plot normalized attributions
smoothgradin_plot_norm <- cowplot::plot_grid(
  plot_attribution(sgin_cox, normalize = TRUE, label = "SGxI(t)"), 
  plot_attribution(sgin_deephit, normalize = TRUE, label = "SGxI(t)"), 
  plot_attribution(sgin_deepsurv, normalize = TRUE, label = "SGxI(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
smoothgradin_plot_norm


# Save plot
ggsave(fig("sim_td_smoothgradin_plot_norm.pdf"), plot = smoothgradin_plot_norm, width = 10, height = 14)

IntegratedGradient

Zero baseline (should be proportional to Gradient x Input)

# Compute IntegratedGradient with 0 baseline
x_ref <- matrix(c(0,0,0,0), nrow = 1)
ig0_cox <- surv_intgrad(exp_coxtime, instance = td_ids, n = 50, x_ref = x_ref, batch_size = 1e6)
ig0_deephit <- surv_intgrad(exp_deephit, instance = td_ids, n = 50, x_ref = x_ref, batch_size = 1e6)
ig0_deepsurv <- surv_intgrad(exp_deepsurv, instance = td_ids, n = 50, x_ref = x_ref, batch_size = 1e6)

# Plot attributions
intgrad0_plot <- cowplot::plot_grid(
  plot_attribution(ig0_cox, label = "IntGrad(t)"), 
  plot_attribution(ig0_deephit, label = "IntGrad(t)"), 
  plot_attribution(ig0_deepsurv, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgrad0_plot 


# Save plot
ggsave(fig("sim_td_intgrad0_plot.pdf"), plot = intgrad0_plot, width = 10, height = 14)
# Plot attributions
intgrad0_plot_comp <- cowplot::plot_grid(
  plot_attribution(ig0_cox, add_comp = TRUE, label = "IntGrad(t)"), 
  plot_attribution(ig0_deephit, add_comp = TRUE, label = "IntGrad(t)"), 
  plot_attribution(ig0_deepsurv, add_comp = TRUE, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgrad0_plot_comp 


# Save plot
ggsave(fig("sim_td_intgrad0_plot_comp.pdf"), plot = intgrad0_plot_comp, width = 10, height = 14)
# Plot contributions
intgrad0_plot_contr <- cowplot::plot_grid(
  plot_contribution(ig0_cox, scale = 0.9, label = "IntGrad(t)"), 
  plot_contribution(ig0_deephit, scale = 0.9, label = "IntGrad(t)"), 
  plot_contribution(ig0_deepsurv, scale = 0.9, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgrad0_plot_contr 


# Save plot
ggsave(fig("sim_td_intgrad0_plot_contr.pdf"), plot = intgrad0_plot_contr, width = 10, height = 14)
# Plot force
intgrad0_plot_force <- cowplot::plot_grid(
  plot_force(ig0_cox, zero_feature = "x4", upper_distance = 0, lower_distance = 0, lower_distance_x1 = 0.05, intgrad0_td_cox = TRUE, label = "IntGrad(t)"), 
  plot_force(ig0_deephit, zero_feature = "x4", upper_distance = 0, lower_distance = 0, lower_distance_x1 = 0.02, label = "IntGrad(t)"),  
  plot_force(ig0_deepsurv, zero_feature = "x4", upper_distance = 0, lower_distance = 0, intgrad0_td_deepsurv = TRUE, label = "IntGrad(t)"), 
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgrad0_plot_force 


# Save plot
ggsave(fig("sim_td_intgrad0_plot_force.pdf"), plot = intgrad0_plot_force, width = 10, height = 14)

Mean baseline

# Compute IntegratedGradient with mean baseline
x_ref <- NULL # default: feature-wise mean
igm_cox <- surv_intgrad(exp_coxtime, instance = td_ids, n = 50, x_ref = x_ref, batch_size = 1e6)
igm_deephit <- surv_intgrad(exp_deephit, instance = td_ids, n = 50, x_ref = x_ref, batch_size = 1e6)
igm_deepsurv <- surv_intgrad(exp_deepsurv, instance = td_ids, n = 50, x_ref = x_ref, batch_size = 1e6)

# Plot attributions
intgradmean_plot <- cowplot::plot_grid(
  plot_attribution(igm_cox, label = "IntGrad(t)"), 
  plot_attribution(igm_deephit, label = "IntGrad(t)"), 
  plot_attribution(igm_deepsurv, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgradmean_plot


# Save plot
ggsave(fig("sim_td_intgradmean_plot.pdf"), plot = intgradmean_plot, width = 10, height = 14)
# Plot attributions
intgradmean_plot_comp <- cowplot::plot_grid(
  plot_attribution(igm_cox, add_comp = TRUE, label = "IntGrad(t)"), 
  plot_attribution(igm_deephit, add_comp = TRUE, label = "IntGrad(t)"), 
  plot_attribution(igm_deepsurv, add_comp = TRUE, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgradmean_plot_comp


# Save plot
ggsave(fig("sim_td_intgradmean_plot_comp.pdf"), plot = intgradmean_plot_comp, width = 10, height = 14)
# Plot contributions
intgradmean_plot_contr <- cowplot::plot_grid(
  plot_contribution(igm_cox, scale = 0.9, label = "IntGrad(t)"), 
  plot_contribution(igm_deephit, scale = 0.9, label = "IntGrad(t)"), 
  plot_contribution(igm_deepsurv, scale = 0.9, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgradmean_plot_contr


# Save plot
ggsave(fig("sim_td_intgradmean_plot_contr.pdf"), plot = intgradmean_plot_contr, width = 10, height = 14)
# Plot force
intgradmean_plot_force <- cowplot::plot_grid(
  plot_force(igm_cox, zero_feature = "x4", upper_distance = 0, lower_distance = 0, intgradmean_td_cox = TRUE, label = "IntGrad(t)"), 
  plot_force(igm_deephit, zero_feature = "x4", upper_distance = 0, lower_distance = 0, intgradmean_td_deephit = TRUE, label = "IntGrad(t)"), 
  plot_force(igm_deepsurv, zero_feature = "x4", upper_distance = 0, lower_distance = 0, intgradmean_td_deepsurv = TRUE, label = "IntGrad(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
intgradmean_plot_force


# Save plot
ggsave(fig("sim_td_intgradmean_plot_force.pdf"), plot = intgradmean_plot_force, width = 10, height = 14)

GradShap

# Compute GradShap
gshap_cox <- surv_gradSHAP(exp_coxtime, instance = td_ids, n = 50, num_samples = 100, 
                           batch_size = 1e6)
gshap_deephit <- surv_gradSHAP(exp_deephit, instance = td_ids, n = 50, num_samples = 100, 
                           batch_size = 1e6)
gshap_deepsurv <- surv_gradSHAP(exp_deepsurv, instance = td_ids, n = 50, num_samples = 100, 
                           batch_size = 1e6)

# Plot attributions
gshap_plot <- cowplot::plot_grid(
  plot_attribution(gshap_cox, label = "GradSHAP(t)"), 
  plot_attribution(gshap_deephit, label = "GradSHAP(t)"), 
  plot_attribution(gshap_deepsurv, label = "GradSHAP(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gshap_plot 


# Save plot
ggsave(fig("sim_td_gshap_plot.pdf"), plot = gshap_plot, width = 10, height = 14)
# Plot attributions
gshap_plot_comp <- cowplot::plot_grid(
  plot_attribution(gshap_cox, add_comp = TRUE, label = "GradSHAP(t)"), 
  plot_attribution(gshap_deephit, add_comp = TRUE, label = "GradSHAP(t)"), 
  plot_attribution(gshap_deepsurv, add_comp = TRUE, label = "GradSHAP(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gshap_plot_comp 


# Save plot
ggsave(fig("sim_td_gshap_plot_comp.pdf"), plot = gshap_plot_comp, width = 10, height = 14)
# Plot contributions
gshap_plot_contr <- cowplot::plot_grid(
  plot_contribution(gshap_cox, scale = 0.9, label = "GradSHAP(t)"), 
  plot_contribution(gshap_deephit, scale = 0.9, label = "GradSHAP(t)"), 
  plot_contribution(gshap_deepsurv, scale = 0.9, label = "GradSHAP(t)"), 
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gshap_plot_contr


# Save plot
ggsave(fig("sim_td_gshap_plot_contr.pdf"), plot = gshap_plot_contr, width = 10, height = 14)
# Plot force
gshap_plot_force <- cowplot::plot_grid(
  plot_force(gshap_cox, upper_distance = 0, lower_distance = 0, zero_feature = "x4", gradshap_td_cox = TRUE, label = "GradSHAP(t)"), 
  plot_force(gshap_deephit, upper_distance = 0, lower_distance = 0, zero_feature = "x4", gradshap_td_deephit = TRUE, label = "GradSHAP(t)"), 
  plot_force(gshap_deepsurv, upper_distance = 0, lower_distance = 0, zero_feature = "x4", gradshap_td_deepsurv = TRUE, label = "GradSHAP(t)"),
  nrow = 3, labels = c("CoxTime", "DeepHit", "DeepSurv"))
gshap_plot_force


# Save plot
ggsave(fig("sim_td_gshap_plot_force.pdf"), plot = gshap_plot_force, width = 10, height = 14)

Main paper plots

# Plot attributions
grad_deephit_plot <- plot_attribution(grad_deephit, label = "Grad(t)")
grad_deephit_plot


# Save plot
ggsave(fig("main_grad_deephit_td.pdf"), plot = grad_deephit_plot, width = 8, height = 4.5)
# Plot attribution, contribution and force plots for CoxTime (GradSHAP)
gshap_coxtime_plot <- cowplot::plot_grid(
  plot_attribution(gshap_cox, add_comp = TRUE, label = "GradSHAP(t)"), 
  plot_contribution(gshap_cox, scale = 0.9, label = "GradSHAP(t)"),
  plot_force(gshap_cox, upper_distance = 0, lower_distance = 0, zero_feature = "x4", gradshap_td_cox = TRUE, label = "GradSHAP(t)"), 
  nrow = 3, labels = c("CoxTime", "CoxTime", "CoxTime"))
gshap_coxtime_plot


# Save plot
ggsave(fig("main_gshap_coxtime_td.pdf"), plot = gshap_coxtime_plot, width = 10, height = 14)